/**
 * Computes the diameter of a binary tree: the number of edges on the longest
 * path between any two nodes (the path need not pass through the root).
 *
 * NOTE(review): TreeNode is declared outside this file; presumably the
 * LeetCode-style struct exposing `left` / `right` child pointers — confirm.
 * The traversal is recursive, so stack depth grows with tree height.
 */
class Solution {
public:
    // Longest path (in edges) found so far during the current traversal.
    // Kept public for backward compatibility; diameterOfBinaryTree resets it
    // before each run.
    int ret = 0;

    // Returns the height of the subtree rooted at `root`, counted in nodes
    // (0 for an empty subtree), and updates `ret` with the longest path that
    // passes through `root`.
    int find(TreeNode* root) {
        if (root == nullptr) {
            return 0;
        }
        const int l = find(root->left);
        const int r = find(root->right);

        // A child subtree of height h (in nodes) puts its deepest node h edges
        // below `root`, so the longest path through `root` has l + r edges.
        ret = std::max(ret, l + r);

        return std::max(l, r) + 1;
    }

    // Returns the diameter (in edges) of the tree rooted at `root`;
    // 0 for an empty or single-node tree.
    int diameterOfBinaryTree(TreeNode* root) {
        // Reset the accumulator so that calling this repeatedly on the same
        // Solution object does not carry over a larger result from a
        // previously processed tree.
        ret = 0;
        find(root);
        return ret;
    }
};
